{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Speedup when running PyTorch models on ONNX\n",
        "\n",
        "In this notebook, we show the speed-up on inference time that we achieve when migrating a Resnet18 from PyTorch to ONNX."
      ],
      "metadata": {
        "id": "7wKMZUbiLF-h"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install onnx -q\n",
        "!pip install onnxruntime -q"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dPgEFGnAEF7k",
        "outputId": "6c1a0b62-6b67-4f27-ec92-4f57617a34bb"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m79.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m43.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.onnx\n",
        "import torchvision.models as models\n",
        "\n",
        "# Initialize model\n",
        "model = models.resnet18(pretrained=True)\n",
        "\n",
        "# Set model to evaluation mode\n",
        "model.eval()\n",
        "\n",
        "# Define dummy input\n",
        "x = torch.randn(1, 3, 224, 224)\n",
        "\n",
        "# Export to ONNX\n",
        "torch.onnx.export(model, x, \"resnet18.onnx\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W_n9EBCPEZhB",
        "outputId": "c40d02c7-bfa0-4ac1-8f9f-10db30bea193"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n",
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
            "100%|██████████| 44.7M/44.7M [00:00<00:00, 58.0MB/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "### Inference time for PyTorch model\n",
        "\n"
      ],
      "metadata": {
        "id": "Rgg6ZbdfLbCy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%time\n",
        "output = model(x)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "089oh1aIKSSi",
        "outputId": "66973c0e-b5fd-43ae-9b1d-f2b66515343b"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CPU times: user 88.4 ms, sys: 516 µs, total: 88.9 ms\n",
            "Wall time: 92 ms\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iY1p1fRnEEej"
      },
      "outputs": [],
      "source": [
        "import onnxruntime as ort\n",
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "# Initialize ONNX Runtime session\n",
        "ort_session = ort.InferenceSession(\"resnet18.onnx\")\n",
        "\n",
        "# Prepare input data (same shape as the dummy input used for exporting the model)\n",
        "input_name = ort_session.get_inputs()[0].name\n",
        "input_shape = (1, 3, 224, 224)\n",
        "input_data = np.random.randn(*input_shape).astype(np.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Inference time for ONNX model"
      ],
      "metadata": {
        "id": "oiUT_pwSLhQu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%time\n",
        "# Run inference\n",
        "ort_inputs = {input_name: input_data}\n",
        "ort_outs = ort_session.run(None, ort_inputs)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BMMV4KGWK--l",
        "outputId": "70122565-1baf-4b01-d5ee-c09253de9204"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CPU times: user 59.8 ms, sys: 0 ns, total: 59.8 ms\n",
            "Wall time: 63.9 ms\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Extract output\n",
        "output_data = ort_outs[0]\n",
        "\n",
        "# Convert ONNX Runtime output to PyTorch tensor if needed\n",
        "output_tensor = torch.from_numpy(output_data)\n",
        "output_tensor.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Qx9sLvAOK3aN",
        "outputId": "c933d134-029a-4e45-dc18-1f8954edd402"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([1, 1000])"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "7BpPTyPIEcWw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}